"""
Phase 7: Dynamics
==================
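Tests temporal patterns in gross vs net positions:
  Table 11: Z → d_gross_assets, d_gross_liab, d_nfa (first-differenced),
            plus CA/GDP in levels for comparison
  Table 12: 5-year lagged Z on key DVs (with contemporaneous Z on CA alongside)
  Table 13: Pre/post-GFC split on gross positions
  Chow test: pre/post-GFC structural break in the income balance regression
  Table 13c: Rate-mediated income balance structural break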
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory: main() reads net_gross_panel.csv from here.
DATA = PROJECT_DIR / "data" / "processed"
# Output directory for the generated Markdown tables; created at import time
# (including parents) so the writers below can assume it exists.
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Three-letter country codes (presumably the OECD membership as ISO 3166-1
# alpha-3 -- confirm). Not referenced by any code visible in this file.
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
]

# Demographic regressors; main() puts them first in every contemporaneous
# specification, so index 0 of the coefficient vector is Z_1.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
# Candidate control variables. main() keeps only those that are present in the
# panel with more than 200 non-missing values.
# NOTE(review): "EBA" presumably refers to the IMF External Balance Assessment
# control set -- confirm.
EBA_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1 and an empty
    string otherwise. All comparisons are strict, so a value exactly on a
    threshold gets the weaker marker. NaN compares false everywhere and
    therefore yields ``''``.
    """
    # Ordered from strictest to loosest: the first cutoff that p beats wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one estimate as a pair of table cells.

    Returns ``(coef_cell, se_cell)``: the coefficient to four decimals followed
    by its significance stars, and the standard error to four decimals in
    parentheses.
    """
    coef_cell = "{:.4f}{}".format(val, stars(p))
    se_cell = "({:.4f})".format(se)
    return coef_cell, se_cell


def run_gls(df, y_var, x_vars, label):
    """Fit ``PanelGLS`` of ``y_var`` on ``x_vars`` and collect the estimates.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel expected to hold ``y_var``, every name in ``x_vars`` and the
        ``iso3`` / ``year`` identifier columns.
    y_var : str
        Dependent-variable column.
    x_vars : sequence of str
        Regressor columns, in the order passed to the estimator.
    label : str
        Model label; stored under ``'model'`` and used in console output.

    Returns
    -------
    dict or None
        ``None`` when a required column is absent or fewer than 50 complete
        rows remain (a ``SKIP`` line is printed in both cases). Otherwise a
        dict with ``model``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared``, ``rho`` and, for every regressor, ``<name>_coef``,
        ``<name>_se`` and ``<name>_p``.
    """
    x_vars = list(x_vars)
    cols = [y_var] + x_vars + ['iso3', 'year']

    # An absent column is reported and skipped rather than raised as KeyError,
    # in line with how the rest of the pipeline treats unavailable series.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  SKIP {label}: missing columns {missing}")
        return None

    sub = df[cols].dropna()
    if len(sub) < 50:
        print(f"  SKIP {label}: only {len(sub)} obs")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[x_vars].values,
            sub['iso3'].values, sub['year'].values)

    result = {
        'model': label, 'dep_var': y_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared, 'rho': gls.rho,
    }
    for i, name in enumerate(x_vars):
        result[f'{name}_coef'] = gls.beta[i]
        result[f'{name}_se'] = gls.se[i]
        result[f'{name}_p'] = gls.pvalues[i]

    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    # Only regressors whose name contains 'Z_' are echoed to the console; the
    # controls are still stored in the returned dict.
    for i, name in enumerate(x_vars):
        if 'Z_' in name:
            sig = stars(gls.pvalues[i])
            print(f"    {name:25s} {gls.beta[i]:10.4f} ({gls.se[i]:.4f}) {sig}")

    return result


def write_table(results, filename, title, key_vars=None):
    """Render regression results as a Markdown table under ``OUT_TABLES``.

    One column per model. Each variable takes two rows: the coefficient with
    significance stars, then the standard error in parentheses. Below a second
    delimiter row come the dependent variable, N, R² and the country count.

    Parameters
    ----------
    results : list of dict
        Result dicts as produced by ``run_gls``. Nothing is written when the
        list is empty or ``None``.
    filename : str
        File name, relative to ``OUT_TABLES``.
    title : str
        Heading placed on the first line of the file.
    key_vars : list of str, optional
        Variables to show, in row order. When omitted, every variable that has
        a ``<name>_coef`` key in any result is shown, in first-seen order.
        A model that lacks a variable gets empty cells for it.
    """
    if not results:
        return

    if key_vars is None:
        suffix = '_coef'
        key_vars = []
        for r in results:
            for k in r:
                if not k.endswith(suffix):
                    continue
                # Strip only the trailing suffix: str.replace would also eat an
                # inner '_coef', and the name would stop matching its keys.
                v = k[:-len(suffix)]
                if v not in key_vars:
                    key_vars.append(v)

    delimiter = "|:---|" + "|".join("---:" for _ in results) + "|"
    lines = [
        f"# {title}\n",
        "| Variable | " + " | ".join(r['model'] for r in results) + " |",
        delimiter,
    ]

    for var in key_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Repeated delimiter row sets the summary statistics apart from the
    # coefficient rows.
    lines.append(delimiter)
    summary_rows = [
        ('Dep var', 'dep_var', '{}'),
        ('N', 'n_obs', '{}'),
        ('R²', 'r_squared', '{:.4f}'),
        ('Countries', 'n_countries', '{}'),
    ]
    for stat, key, fmt_str in summary_rows:
        cells = "".join(f" {fmt_str.format(r[key])} |" for r in results)
        lines.append(f"| {stat} |" + cells)

    lines.append("\n*Panel GLS with AR(1) errors. Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # The table holds non-ASCII text (R², and labels/titles may carry Δ or —),
    # so the encoding is pinned instead of depending on the platform locale.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# Sample windows around the global financial crisis. The years in between
# (2008-2009) belong to neither window.
_PRE_GFC_LAST_YEAR = 2007
_POST_GFC_FIRST_YEAR = 2010
# A column is used only when it has more than this many non-missing values.
_MIN_COVERAGE = 200
# Smallest sub-sample the Chow test will fit; same floor as run_gls applies.
_MIN_SUBSAMPLE_OBS = 50


def _banner(title, width=50):
    """Print a blank line, then ``title`` framed by two rules of ``=``."""
    rule = "=" * width
    print("\n" + rule)
    print(title)
    print(rule)


def _well_covered(df, names):
    """Return the ``names`` present in ``df`` with enough non-missing values."""
    return [v for v in names
            if v in df.columns and df[v].notna().sum() > _MIN_COVERAGE]


def _first_diff_table(df, controls):
    """Table 11: demographics on first-differenced positions, plus CA/GDP."""
    _banner("TABLE 11: FIRST-DIFFERENCED (FLOW PROXY)")

    x_vars = DEMO_VARS + controls
    runs = []
    for dep in _well_covered(df, ['d_gross_assets_gdp', 'd_gross_liab_gdp',
                                  'd_nfa_gdp']):
        short = dep.replace('d_', 'Δ').replace('_gdp', '')
        runs.append(run_gls(df, dep, x_vars, f'Full: {short}'))

    # Current account in levels, for comparison with the differenced columns.
    if 'ca_gdp' in df.columns:
        runs.append(run_gls(df, 'ca_gdp', x_vars, 'Full: CA/GDP'))

    write_table([r for r in runs if r is not None], "dynamics_first_diff.md",
                "Table 11: Demographics and Changes in Gross Positions (First Differences)",
                key_vars=x_vars)


def _lagged_table(df, controls):
    """Table 12: 5-year lagged demographics on the key dependent variables."""
    _banner("TABLE 12: 5-YEAR LAGGED DEMOGRAPHICS")

    lag_vars = _well_covered(df, ['Z_1_lag5', 'Z_2_lag5', 'Z_3_lag5'])
    if not lag_vars:
        print("  No lagged Z variables available")
        return

    runs = []
    for dep in _well_covered(df, ['ca_gdp', 'gross_assets_gdp',
                                  'income_balance_gdp', 'debt_assets_gdp']):
        short = dep.replace('_gdp', '')
        runs.append(run_gls(df, dep, lag_vars + controls, f'Lag5: {short}'))

    # Contemporaneous CA specification shown side by side with the lagged one.
    if 'ca_gdp' in df.columns:
        runs.append(run_gls(df, 'ca_gdp', DEMO_VARS + controls, 'Contemp: CA'))

    write_table([r for r in runs if r is not None], "dynamics_lagged.md",
                "Table 12: 5-Year Lagged Demographics on Key Dependent Variables",
                key_vars=lag_vars + DEMO_VARS + controls)


def _gfc_split_table(df, controls):
    """Table 13: the gross-position regressions re-run on each GFC window."""
    _banner("TABLE 13: PRE/POST-GFC STRUCTURAL BREAK")

    gross_dvs = [v for v in ['gross_assets_gdp', 'gross_liab_gdp',
                             'gross_ifi', 'ca_gdp']
                 if v in df.columns]
    windows = [('Pre-GFC', df['year'] <= _PRE_GFC_LAST_YEAR),
               ('Post-GFC', df['year'] >= _POST_GFC_FIRST_YEAR)]

    runs = []
    for label, mask in windows:
        sub = df[mask].copy()
        for dep in gross_dvs:
            short = dep.replace('_gdp', '')
            runs.append(run_gls(sub, dep, DEMO_VARS + controls,
                                f'{label}: {short}'))

    write_table([r for r in runs if r is not None], "dynamics_gfc_split.md",
                "Table 13: Pre vs Post-GFC — Gross Positions",
                key_vars=DEMO_VARS)


def _fit_with_rss(sample, dep, x_vars):
    """Fit PanelGLS on ``sample``; return the model and its residual sum of squares.

    The residuals are ``y - X @ beta`` on the untransformed data.
    """
    gls = PanelGLS()
    gls.fit(sample[dep].values, sample[x_vars].values,
            sample['iso3'].values, sample['year'].values)
    resid = sample[dep].values - sample[x_vars].values @ gls.beta
    return gls, np.sum(resid ** 2)


def _chow_test_income_balance(df, controls):
    """Chow test for a pre/post-GFC break in the income balance regression."""
    _banner("CHOW TEST: STRUCTURAL BREAK ON INCOME BALANCE")

    dep = 'income_balance_gdp'
    if dep not in df.columns:
        return

    from scipy import stats as scipy_stats

    x_vars = DEMO_VARS + controls
    k = len(x_vars)
    sample = df[[dep] + x_vars + ['iso3', 'year']].dropna()

    in_pre = sample['year'] <= _PRE_GFC_LAST_YEAR
    in_post = sample['year'] >= _POST_GFC_FIRST_YEAR
    pre = sample[in_pre]
    post = sample[in_post]
    # The restricted (pooled) model must be fitted on exactly the observations
    # of the two sub-samples. Leaving the excluded transition years in would
    # add their residuals to the pooled RSS only and inflate the F statistic.
    pooled = sample[in_pre | in_post]

    if min(len(pre), len(post)) < _MIN_SUBSAMPLE_OBS:
        print(f"  SKIP Chow test: only {len(pre)} pre-GFC / "
              f"{len(post)} post-GFC obs")
        return

    gls_p, rss_p = _fit_with_rss(pooled, dep, x_vars)
    gls_pre, rss_pre = _fit_with_rss(pre, dep, x_vars)
    gls_post, rss_post = _fit_with_rss(post, dep, x_vars)
    n_p, n_pre, n_post = gls_p.n_obs, gls_pre.n_obs, gls_post.n_obs

    dof = n_pre + n_post - 2 * k
    chow_num = (rss_p - rss_pre - rss_post) / k
    chow_den = (rss_pre + rss_post) / dof if dof > 0 else np.nan
    chow_f = chow_num / chow_den if chow_den > 0 else np.nan
    chow_p = 1 - scipy_stats.f.cdf(chow_f, k, dof) if not np.isnan(chow_f) else np.nan

    # stars() returns '' for NaN, the same as the inline ladder it replaces.
    sig = stars(chow_p)

    print(f"  Income balance Chow test:")
    print(f"    Pooled RSS={rss_p:.2f}, N={n_p}")
    print(f"    Pre-GFC RSS={rss_pre:.2f}, N={n_pre}, Z₁={gls_pre.beta[0]:.2f}")
    print(f"    Post-GFC RSS={rss_post:.2f}, N={n_post}, Z₁={gls_post.beta[0]:.2f}")
    print(f"    Chow F={chow_f:.3f}, p={chow_p:.4f}")
    print(f"    Structural break: {'CONFIRMED' if chow_p < 0.05 else 'NOT confirmed'} {sig}")

    chow_lines = [
        "# Chow Test: Structural Break on Income Balance\n",
        "| Statistic | Value |",
        "|:---|---:|",
        f"| Pooled N | {n_p} |",
        f"| Pre-GFC N | {n_pre} |",
        f"| Post-GFC N | {n_post} |",
        f"| Pre-GFC Z₁ | {gls_pre.beta[0]:.2f} |",
        f"| Post-GFC Z₁ | {gls_post.beta[0]:.2f} |",
        f"| Chow F-statistic | {chow_f:.3f} |",
        f"| Chow p-value | {chow_p:.4f} |",
        f"| Structural break | {'Yes' if chow_p < 0.05 else 'No'}{sig} |",
        "\n*Chow (1960) test for structural break at 2008-2009.*",
        "*Break point excludes 2008-2009 transition years.*",
    ]

    # 'Z₁' is outside many legacy code pages, so the encoding is explicit.
    (OUT_TABLES / "chow_test_income_balance.md").write_text(
        '\n'.join(chow_lines), encoding='utf-8')
    print(f"  Saved: chow_test_income_balance.md")


def _rate_mediated_table(df, controls):
    """Table 13c: income balance with and without the real bond rate control."""
    _banner("TABLE 13c: RATE-MEDIATED INCOME BALANCE BREAK")

    dep = 'income_balance_gdp'
    required = [dep, 'real_bond_10y']
    missing = [v for v in required if v not in df.columns]
    if missing:
        print(f"  SKIP: missing columns {missing}")
        return

    rate_controls = ['real_bond_10y']
    baseline_x = DEMO_VARS + controls
    rates_x = baseline_x + rate_controls

    # Column order: full, pre-GFC, post-GFC; each as baseline then +rates.
    samples = [
        ('Full', df),
        ('Pre-GFC', df[df['year'] <= _PRE_GFC_LAST_YEAR].copy()),
        ('Post-GFC', df[df['year'] >= _POST_GFC_FIRST_YEAR].copy()),
    ]
    runs = []
    for period, sample in samples:
        runs.append(run_gls(sample, dep, baseline_x, f'{period}: baseline'))
        runs.append(run_gls(sample, dep, rates_x, f'{period}: +rates'))
    results = [r for r in runs if r is not None]

    write_table(results, "dynamics_rate_mediated_break.md",
                "Table 13c: Rate-Mediated Income Balance Structural Break",
                key_vars=rates_x)

    # Interpretation: looked up by label so a skipped column cannot shift
    # which model each number is read from.
    by_label = {r['model']: r for r in results}
    needed = ['Pre-GFC: baseline', 'Pre-GFC: +rates',
              'Post-GFC: baseline', 'Post-GFC: +rates']
    if not all(lbl in by_label for lbl in needed):
        return

    nan = float('nan')
    z1_pre_base, z1_pre_rate, z1_post_base, z1_post_rate = (
        by_label[lbl].get('Z_1_coef', nan) for lbl in needed)
    print(f"\n  Pre-GFC:  Z₁ baseline={z1_pre_base:.1f} → +rates={z1_pre_rate:.1f}")
    print(f"  Post-GFC: Z₁ baseline={z1_post_base:.1f} → +rates={z1_post_rate:.1f}")
    # Relative change is only meaningful when the baseline is not ~zero.
    if abs(z1_pre_base) > 0.01:
        attenuation = (z1_pre_rate - z1_pre_base) / z1_pre_base * 100
        print(f"  Pre-GFC attenuation from rates: {attenuation:.0f}%")


def main():
    """Run Phase 7: load the panel and produce every dynamics table.

    Reads ``net_gross_panel.csv`` from ``DATA``, keeps years up to 2024,
    selects the usable controls from ``EBA_CONTROLS`` and then writes, in
    order: Table 11 (first differences), Table 12 (5-year lags), Table 13
    (pre/post-GFC split), the income balance Chow test and Table 13c
    (rate-mediated break). Progress is printed to stdout; the tables are
    written as Markdown files under ``OUT_TABLES``.
    """
    print("=" * 70)
    print("PHASE 7: DYNAMICS")
    print("=" * 70)

    df = pd.read_csv(DATA / "net_gross_panel.csv")
    df = df[df['year'] <= 2024].copy()
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    controls = _well_covered(df, EBA_CONTROLS)

    _first_diff_table(df, controls)
    _lagged_table(df, controls)
    _gfc_split_table(df, controls)
    _chow_test_income_balance(df, controls)
    _rate_mediated_table(df, controls)

    print("\n" + "=" * 70)
    print("PHASE 7 COMPLETE")
    print("=" * 70)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
